{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RL Exercise 4 - Online learning with Deep Q Networks\n",
    "\n",
    "**GOAL:** The goal of this exercise is to demonstrate set up a server that simultaneously serves and learns from a policy. Here we're going to use Deep Q Networks (DQN) algorithm, which has better sample efficiency compared to PPO. This lets it learn a policy from fewer experiences.\n",
    "\n",
    "DQN was *the* first deep RL algorithm, and is described in detail in https://arxiv.org/abs/1312.5602.\n",
    "\n",
    "In DQN, instead of training a policy network to directly emit output actions from the observation, we learn a Q function that models the expected outcome of taking certain actions. This model is then used to compute the optimal actions at each step.\n",
    "\n",
    "Unlike policy gradient algorithms such as PPO, DQN can learn from past experiences through *experience replay*. This allows DQN to use experiences multiple times over the course of training, improving its sample efficiency. In this exercise we are going to use a single-process configuration for DQN, but RLlib does provide a distributed variant of DQN: https://ray.readthedocs.io/en/latest/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x\n",
    "\n",
    "![dqn](https://raw.githubusercontent.com/ucbrise/risecamp/risecamp2018/ray/tutorial/rllib_exercises/dqn.png)\n",
    "\n",
    "To understand how to use **RLlib**, see the documentation at http://rllib.io.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running a simple policy server"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE**: Open a new terminal in Jupyter lab using the \"+\" button, and run (you might need to `cd rllib_exercises` first):\n",
    "\n",
    "`$ python serving/simple_policy_server.py --action-size=2 --observation-size=4 --run=DQN --checkpoint-file=cartpole`\n",
    "\n",
    "Take a look at the file to see how the policy server is implemented. In policy serving, there is no step() function that RLlib can run.\n",
    "Intuitively, this is because there is no simulator -- the policy must interact with the real world and you can't call step() on that.\n",
    "To support this use case, RLlib provides a special [ServingEnv environment type](https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/serving_env.py) in which the environment executes in its own thread of control. When a decision needs to be made, the policy is queried via `self.get_action(obs)`, and rewards reported via `self.log_returns()`.\n",
    "\n",
    "In `simple_policy_server.py`, you'll notice that the env makes use of the built in [PolicyServer class](https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/policy_server.py). All this class does is handle incoming HTTP requests and does the following for each episode:\n",
    "1. Call `self.start_episode()` to get a new episode_id.\n",
    "2. Call `self.get_action(episode_id, obs)` or `self.log_action(episode_id, obs, action)`\n",
    "3. Call `self.log_returns(episode_id, reward)`\n",
    "4. Call `self.end_episode(episode_id, obs)`\n",
    "\n",
    "Note that PolicyServer is a very basic server, however ServingEnv can work with any kind of Python server implementation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connect to the policy server you started and initialize an environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.rllib.utils.policy_client import PolicyClient\n",
    "client = PolicyClient(\"http://localhost:8900\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start by solving a problem you've already seen before: CartPole. It's important to note however that here RLlib is not interacting directly with the CartPole-v0 environment instance at all. All policy interactions and learning will be through the HTTP policy client:\n",
    "\n",
    "![client](https://raw.githubusercontent.com/ucbrise/risecamp/risecamp2018/ray/tutorial/rllib_exercises/client.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "env = gym.make(\"CartPole-v0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE**: Implement `run_one_episode(env)` below to rollout a CartPole episode using the client and env created above. You'll find the [policy client documentation](https://ray.readthedocs.io/en/latest/rllib-package-ref.html#ray.rllib.utils.PolicyClient) to be helpful here.\n",
    "\n",
    "You should see a mean episode reward of about 20, which corresponds to that of an untrained policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_one_episode(env):\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    episode_id = client.start_episode() # We start by requesting a new episode id from the server\n",
    "    total_reward = 0\n",
    "    while not done:\n",
    "        action = ??? # TODO call get_action to get the action for the current observation\n",
    "        obs, rew, done, info = env.step(action)\n",
    "        ??? # TODO tell the server about the recent returns of the action\n",
    "        if done:\n",
    "            ??? # TODO tell the server the episode ended\n",
    "        total_reward += rew\n",
    "    print(\"Episode reward\", total_reward)\n",
    "\n",
    "run_one_episode(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE**: Run episodes until the server reaches a peak reward of ~200 (this might take a few hundred episodes). In another terminal, open TensorBoard to monitor the learning curve of the policy.\n",
    "\n",
    "If you run into problems, you can restart the server with Ctrl-C, and it will load its last saved state from the file specifed by `--checkpoint-file`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(10):\n",
    "    run_one_episode(env)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serving a Pong AI\n",
    "\n",
    "Next, we'll use the same policy server to train a Pong AI. Interrupt your cartpole server with Ctrl-C, and use the following configuration to start a policy server suitable for learning a Pong policy. Note here that the observation size is only slightly increased from 4 to 8. This is because the policy is going to operate over a minimal state description of the Pong game, instead of on raw images (that would take much longer to train and would require a GPU).\n",
    "\n",
    "**EXERCISE**: Set up the pong policy server, replacing the cartpole server.\n",
    "\n",
    "`$ python serving/simple_policy_server.py --action-size=3 --observation-size=8 --run=DQN --checkpoint-file=pong`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll set up a web server so that you can play Pong against the policy in your browser.\n",
    "\n",
    "**EXERCISE**: Set up the web server in another terminal and play a game of Pong against the AI.\n",
    "\n",
    "`$ python serving/pong_web_server.py`\n",
    "\n",
    "The web server will listen on port 8900 for HTTP connections. It will internally connect to the policy server already running on port 3000.\n",
    "\n",
    "#### **In a new browser tab, you should go to a URL like `https://hub.mybinder.org/user/ray-project-tutorial-YOUR_NOTEBOOK_ID/proxy/8900/serving/javascript-pong/static/index.html`**\n",
    "\n",
    "![web](https://raw.githubusercontent.com/ucbrise/risecamp/risecamp2018/ray/tutorial/rllib_exercises/web.png)\n",
    "\n",
    "As you play the game of Pong, notice the policy decisions being made by the policy server. It will look something like this:\n",
    "\n",
    "![log](https://raw.githubusercontent.com/ucbrise/risecamp/risecamp2018/ray/tutorial/rllib_exercises/log.png)\n",
    "\n",
    "Also note that the policy is probably not very good. If you play enough games, it will eventually get better, but we have a faster way of fixing this..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE**: To further improve your policy, run the `python serving/do_rollouts.py` script to automatically run many more games against the live policy. This will eventually train the agent to play competitively, after 60k or so steps. Try playing against the server while these rollouts are happening in the background! Do you notice the policy improving? The learning curve in TensorBoard will look something like this:\n",
    "\n",
    "![learning](https://raw.githubusercontent.com/ucbrise/risecamp/risecamp2018/ray/tutorial/rllib_exercises/learning.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optional exercise: Learning from logs data\n",
    "\n",
    "We've included two datasets, `data_small.gz` and `data_large.gz` that you can use to help your training. The following cell will load the small dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import gzip\n",
    "episodes = []  # list of episodes\n",
    "for line in gzip.open(\"serving/data_small.gz\"):\n",
    "    episodes.append(json.loads(line.decode(\"utf-8\")))\n",
    "for i in range(10):\n",
    "    print(episodes[0][i])  # each episode is a list of these dicts representing steps"
   ]
  },
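  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loader above assumes one JSON-encoded episode (a list of step dicts) per line of a gzip file. The sketch below reproduces that layout with a temporary file. The step fields `obs`, `action` and `reward` are hypothetical names chosen for illustration; inspect the output printed above for the keys the real datasets use.\n",
    "\n",
    "```python\n",
    "import gzip\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "# Hypothetical step fields, for illustration only.\n",
    "toy_episodes = [\n",
    "    [{'obs': [0.0, 1.0], 'action': 1, 'reward': 0.0},\n",
    "     {'obs': [0.5, 1.0], 'action': 0, 'reward': 1.0}],\n",
    "    [{'obs': [1.0, 0.0], 'action': 2, 'reward': -1.0}],\n",
    "]\n",
    "\n",
    "path = os.path.join(tempfile.mkdtemp(), 'toy_log.gz')\n",
    "\n",
    "# Write one JSON-encoded episode per line.\n",
    "with gzip.open(path, 'wt', encoding='utf-8') as f:\n",
    "    for episode in toy_episodes:\n",
    "        print(json.dumps(episode), file=f)\n",
    "\n",
    "# Read it back: each line decodes to one episode.\n",
    "with gzip.open(path, 'rt', encoding='utf-8') as f:\n",
    "    loaded = [json.loads(line) for line in f]\n",
    "\n",
    "print(len(loaded), len(loaded[0]))  # 2 2\n",
    "```"
   ]
  },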
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**EXERCISE**: Using the `client.log_action(episode_id, obs, action)` and `client.log_returns(episode_id, reward)` calls, replay the log data to the policy server. Try both the small and large dataset.\n",
    "Loading the large dataset can take a while due to the naive approach of sending steps one by one. While it is in progress, try playing the Pong AI and see if you can see any improvement (you might want to change the --checkpoint-file flag to start over from a fresh policy).\n",
    "\n",
    "Note that in a production setting, you wouldn't want to send logs to the server over the network like this. Instead, you can write an environment that reads from log shards in parallel and use RLlib to distribute the computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "steps = 0  # step counter\n",
    "for episode in episodes:\n",
    "    episode_id = client.start_episode()\n",
    "    print(\"Replaying episode\", episode_id, \"steps total\", steps)\n",
    "    for step in episode:\n",
    "        ??? # TODO: log the action\n",
    "        ??? # TODO: log the returns\n",
    "        steps += 1\n",
    "    client.end_episode(episode_id, [0]*8)  # assume last observation is all zeros"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## More optional exercises\n",
    "\n",
    "**EXERCISE**: Try the above exercises but starting the server with `--run=PG` to use a policy gradient algorithm instead of DQN. How does this compare in learning performance? What about in ability to leverage offline data?\n",
    "\n",
    "**EXERCISE**: Come up with a simple \"toy environment\" with a small action and observation space. Use one of the RLlib [algorithms](https://ray.readthedocs.io/en/latest/rllib.html#algorithms) to solve that environment.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
